#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Mon Jan  7 20:25:13 2019

@author: DELL
"""
import numpy as np
import svm1_callbacks
import alldata
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
"""
Label all of the data, based on the anomaly points detected by the SVM.
"""
# Build one list of labels per engine, parallel to alldata.engine_cycles.
# For an engine with `total` cycles and SVM result `threshold`:
#   * positions whose cycle value is <= threshold get the constant label
#     total - threshold;
#   * later positions get total - (1-based position), i.e. a countdown to 0.
# NOTE(review): presumably this is a remaining-useful-life target that is
# held flat before the detected anomaly point -- confirm with the model code.
cycle_label = []
for idx in range(len(alldata.engine)):
    cycles = alldata.engine_cycles[idx]
    threshold = svm1_callbacks.result[idx]
    total = len(cycles)
    cycle_label.append([
        total - threshold if cycles[k] <= threshold else total - (k + 1)
        for k in range(total)
    ])
"""
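Convert the data into the LSTM input format [batch, timestep, input_dim].
For each engine, train_x, train_y and train_c have shapes
(x, timestep, input_dim), (x, timestep, 1) and (x, timestep, 1) respectively
(the original note quoted timestep=10 and input_dim=24; the code below uses timestep=31).
The per-engine results are then concatenated together.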
"""
# Slice every engine into overlapping sliding windows of `timestep`
# consecutive cycles (stride 1).  After this block, for each engine i:
#   train_x[i] -- list of windows, each a nested list of timestep rows taken
#                 from alldata.engine[i]
#   train_y[i] -- list of (timestep, 1) arrays taken from cycle_label[i]
#   train_c[i] -- list of (timestep, 1) arrays taken from
#                 alldata.engine_cycles[i]
# An engine with fewer than `timestep` rows contributes empty lists.
train_x, train_y, train_c = [], [], []
timestep = 31
for i in range(len(alldata.engine)):
    # Convert each per-engine sequence to an array once, instead of once per
    # window: the conversion is loop-invariant and rebuilding it for every
    # window made the work quadratic in the engine length.
    features = np.array(alldata.engine[i])
    labels = np.array(cycle_label[i])
    cycles = np.array(alldata.engine_cycles[i])
    n_windows = len(alldata.engine[i]) - timestep + 1

    train_x.append([
        features[start:start + timestep, :].tolist()
        for start in range(n_windows)
    ])
    # .copy() keeps every window an independent array (as before), so an
    # in-place edit of one window cannot leak into its overlapping neighbours.
    train_y.append([
        labels[start:start + timestep, np.newaxis].copy()
        for start in range(n_windows)
    ])
    train_c.append([
        cycles[start:start + timestep, np.newaxis].copy()
        for start in range(n_windows)
    ])
# Stack the per-engine window lists along the batch axis.
# One concatenate per quantity replaces the pairwise concatenate-in-a-loop,
# which recopied the accumulated array on every iteration.  It also always
# yields ndarrays (the old code left plain lists when there was only one
# engine, which broke the slicing below).  Engines that produced no windows
# are skipped because an empty list cannot be concatenated with 3-D windows.
trainx = np.concatenate([w for w in train_x if len(w)], axis=0)
trainy = np.concatenate([w for w in train_y if len(w)], axis=0)
trainc = np.concatenate([w for w in train_c if len(w)], axis=0)

# Keep only the last time step of every window as the per-window target
# (trainy_) and the matching cycle value (trainc_); both are (batch, 1).
trainy_ = trainy[:, -1, :]
trainc_ = trainc[:, -1, :]

# The former bare `shuffle(trainx, trainy_)` call was dropped:
# sklearn.utils.shuffle returns shuffled copies and leaves its inputs
# untouched, so discarding its result did nothing.  train_test_split already
# shuffles before splitting (made explicit here), and leaving trainx /
# trainy_ / trainc_ in their original order keeps them row-aligned.
x_train, x_validate, y_train, y_validate = train_test_split(
    trainx, trainy_, test_size=0.3, shuffle=True
)